{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from config import get_config, get_weights_file_path\n",
    "from train import get_model, get_ds, run_validation, causal_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using device: cuda\n",
      "Max length of source sentence: 309\n",
      "Max length of target sentence: 274\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pick the GPU when one is available, otherwise fall back to the CPU\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(\"Using device:\", device)\n",
    "\n",
    "# Build the dataloaders, the tokenizers and an untrained model sized to both vocabularies\n",
    "config = get_config()\n",
    "train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)\n",
    "model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)\n",
    "\n",
    "# Load the pretrained weights from the checkpoint tagged \"19\"\n",
    "# (presumably the epoch number - confirm against get_weights_file_path).\n",
    "# map_location remaps the stored tensors onto `device`, so a checkpoint that was written\n",
    "# from CUDA tensors can still be loaded when this notebook runs on the CPU.\n",
    "model_filename = get_weights_file_path(config, \"19\")\n",
    "state = torch.load(model_filename, map_location=device)\n",
    "model.load_state_dict(state['model_state_dict'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------------------------------------------------------------------------\n",
      "            SOURCE: Hence it is that for so long a time, and during so much fighting in the past twenty years, whenever there has been an army wholly Italian, it has always given a poor account of itself; the first witness to this is Il Taro, afterwards Allesandria, Capua, Genoa, Vaila, Bologna, Mestri.\n",
      "            TARGET: Di qui nasce che, in tanto tempo, in tante guerre fatte ne' passati venti anni, quando elli è stato uno esercito tutto italiano, sempre ha fatto mala pruova. Di che è testimone prima el Taro, di poi Alessandria, Capua, Genova, Vailà, Bologna, Mestri.\n",
      "  PREDICTED GREEDY: Di qui nasce che , in tanto , in tanto tempo , in tante guerre fatte ne ' passati\n",
      "    PREDICTED BEAM: Di qui nasce che , in tanto tempo , in tante guerre fatte ne ' passati venti anni ,\n",
      "--------------------------------------------------------------------------------\n",
      "            SOURCE: She went out.\n",
      "            TARGET: Aprì lo sportello e venne fuori.\n",
      "  PREDICTED GREEDY: Aprì lo sportello e venne fuori .\n",
      "    PREDICTED BEAM: Aprì lo sportello e venne fuori . — Ecco , poi uscì e andò via . — Ecco ,\n",
      "--------------------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "def beam_search_decode(model, beam_size, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):\n",
    "    \"\"\"Decode a single source sentence with beam search and return the best hypothesis.\n",
    "\n",
    "    Args:\n",
    "        model: object exposing encode(), decode() and project(), used exactly as called below.\n",
    "        beam_size: number of hypotheses kept after every decoding step (must be >= 1).\n",
    "        source: encoder input; a batch of one sentence is assumed, because only index 0 of\n",
    "            every intermediate result is read.\n",
    "        source_mask: mask forwarded unchanged to model.encode() and model.decode().\n",
    "        tokenizer_src: unused; kept so the signature mirrors greedy_decode.\n",
    "        tokenizer_tgt: tokenizer whose token_to_id() resolves '[SOS]' and '[EOS]'.\n",
    "        max_len: the search stops as soon as any hypothesis holds this many tokens.\n",
    "        device: device on which decoder inputs and masks are placed.\n",
    "\n",
    "    Returns:\n",
    "        1-D tensor with the token ids of the highest scoring hypothesis, starting with [SOS].\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if beam_size is smaller than 1.\n",
    "    \"\"\"\n",
    "    if beam_size < 1:\n",
    "        raise ValueError(f'beam_size must be at least 1, got {beam_size}')\n",
    "\n",
    "    sos_idx = tokenizer_tgt.token_to_id('[SOS]')\n",
    "    eos_idx = tokenizer_tgt.token_to_id('[EOS]')\n",
    "\n",
    "    # Precompute the encoder output and reuse it for every step\n",
    "    encoder_output = model.encode(source, source_mask)\n",
    "    # Initialize the decoder input with the sos token\n",
    "    decoder_initial_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)\n",
    "\n",
    "    # Every candidate is a pair (token ids of shape (1, length), cumulative log probability).\n",
    "    # The start hypothesis has probability 1, which is 0 in log space.\n",
    "    candidates = [(decoder_initial_input, 0.0)]\n",
    "\n",
    "    while True:\n",
    "\n",
    "        # If a candidate has reached the maximum length, stop the search\n",
    "        if any(cand.size(1) == max_len for cand, _ in candidates):\n",
    "            break\n",
    "\n",
    "        # Create a new list of candidates\n",
    "        new_candidates = []\n",
    "\n",
    "        for candidate, score in candidates:\n",
    "\n",
    "            # A hypothesis that already ends with eos is not expanded, but it has to stay in the\n",
    "            # pool with its score. Dropping it here would throw away finished translations and\n",
    "            # let the search keep generating text after the sentence is complete.\n",
    "            if candidate[0][-1].item() == eos_idx:\n",
    "                new_candidates.append((candidate, score))\n",
    "                continue\n",
    "\n",
    "            # Build the candidate's mask\n",
    "            candidate_mask = causal_mask(candidate.size(1)).type_as(source_mask).to(device)\n",
    "            # calculate output\n",
    "            out = model.decode(encoder_output, source_mask, candidate, candidate_mask)\n",
    "            # Normalize the projection of the last position into log probabilities.\n",
    "            # log_softmax leaves values that are already log probabilities unchanged, so this is\n",
    "            # correct whether model.project returns raw logits or log probabilities.\n",
    "            log_prob = torch.log_softmax(model.project(out[:, -1]), dim=-1)\n",
    "            # topk raises when asked for more entries than the vocabulary holds\n",
    "            k = min(beam_size, log_prob.size(-1))\n",
    "            topk_log_prob, topk_idx = torch.topk(log_prob, k, dim=1)\n",
    "            for i in range(k):\n",
    "                # for each of the top k continuations, get the token and its log probability\n",
    "                token = topk_idx[0][i].unsqueeze(0).unsqueeze(0)\n",
    "                token_log_prob = topk_log_prob[0][i].item()\n",
    "                # create a new candidate by appending the token to the current candidate\n",
    "                new_candidate = torch.cat([candidate, token], dim=1)\n",
    "                # log probabilities add up where probabilities would multiply\n",
    "                new_candidates.append((new_candidate, score + token_log_prob))\n",
    "\n",
    "        # Sort the new candidates by their score and keep only the best beam_size of them\n",
    "        candidates = sorted(new_candidates, key=lambda x: x[1], reverse=True)[:beam_size]\n",
    "\n",
    "        # If all the surviving candidates have reached the eos token, stop\n",
    "        if all(cand[0][-1].item() == eos_idx for cand, _ in candidates):\n",
    "            break\n",
    "\n",
    "    # Return the best candidate; squeeze(0) removes only the batch dimension, so even a\n",
    "    # one-token result stays 1-D, matching what greedy_decode returns.\n",
    "    return candidates[0][0].squeeze(0)\n",
    "\n",
    "def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):\n",
    "    \"\"\"Decode one sentence by repeatedly appending the highest scoring next token.\n",
    "\n",
    "    Generation ends once the sequence holds max_len tokens or the [EOS] token was appended.\n",
    "    tokenizer_src is accepted but not used.\n",
    "\n",
    "    Returns:\n",
    "        The generated token ids, starting with [SOS], with the leading batch dimension removed.\n",
    "    \"\"\"\n",
    "    sos_idx = tokenizer_tgt.token_to_id('[SOS]')\n",
    "    eos_idx = tokenizer_tgt.token_to_id('[EOS]')\n",
    "\n",
    "    # The encoder runs once; its output is shared by every decoding step\n",
    "    memory = model.encode(source, source_mask)\n",
    "\n",
    "    # Sequence generated so far, seeded with the sos token\n",
    "    decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)\n",
    "\n",
    "    while decoder_input.size(1) != max_len:\n",
    "        # causal mask sized to the current target length\n",
    "        current_len = decoder_input.size(1)\n",
    "        decoder_mask = causal_mask(current_len).type_as(source_mask).to(device)\n",
    "\n",
    "        # run the decoder over everything generated so far\n",
    "        out = model.decode(memory, source_mask, decoder_input, decoder_mask)\n",
    "\n",
    "        # project the last position only and keep the index of its maximum\n",
    "        scores = model.project(out[:, -1])\n",
    "        next_word = torch.max(scores, dim=1)[1]\n",
    "\n",
    "        # append the chosen token as a (1, 1) tensor of the same type as the source\n",
    "        appended = torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)\n",
    "        decoder_input = torch.cat([decoder_input, appended], dim=1)\n",
    "\n",
    "        if next_word == eos_idx:\n",
    "            break\n",
    "\n",
    "    return decoder_input.squeeze(0)\n",
    "\n",
    "def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, num_examples=2, beam_size=3):\n",
    "    \"\"\"Print greedy and beam search translations for the first batches of a dataset.\n",
    "\n",
    "    Note: this definition replaces the run_validation imported from train earlier in the notebook.\n",
    "\n",
    "    Args:\n",
    "        model: translation model; it is switched to eval mode and run without gradients.\n",
    "        validation_ds: iterable of batches providing the keys 'encoder_input', 'encoder_mask',\n",
    "            'src_text' and 'tgt_text'. Every batch must hold exactly one sentence.\n",
    "        tokenizer_src: forwarded to the decoding functions.\n",
    "        tokenizer_tgt: forwarded to the decoding functions and used to turn ids back into text.\n",
    "        max_len: maximum number of tokens (including [SOS]) generated per sentence.\n",
    "        device: device the batch tensors are moved to.\n",
    "        print_msg: callable that receives every output line, e.g. print.\n",
    "        num_examples: the loop stops after this many batches have been printed.\n",
    "        beam_size: beam width handed to beam_search_decode; the default of 3 is the value\n",
    "            that used to be hard-coded.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if a batch contains more than one sentence.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    count = 0\n",
    "\n",
    "    console_width = 80\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for batch in validation_ds:\n",
    "            count += 1\n",
    "            encoder_input = batch[\"encoder_input\"].to(device) # (b, seq_len)\n",
    "            encoder_mask = batch[\"encoder_mask\"].to(device) # (b, 1, 1, seq_len)\n",
    "\n",
    "            # both decoders build (1, 1) decoder inputs, so they only support a batch of one\n",
    "            assert encoder_input.size(\n",
    "                0) == 1, \"Batch size must be 1 for validation\"\n",
    "\n",
    "            # decode the same sentence with both strategies so they can be compared side by side\n",
    "            model_out_greedy = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)\n",
    "            model_out_beam = beam_search_decode(model, beam_size, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)\n",
    "\n",
    "            source_text = batch[\"src_text\"][0]\n",
    "            target_text = batch[\"tgt_text\"][0]\n",
    "            model_out_text_beam = tokenizer_tgt.decode(model_out_beam.detach().cpu().numpy())\n",
    "            model_out_text_greedy = tokenizer_tgt.decode(model_out_greedy.detach().cpu().numpy())\n",
    "\n",
    "            # Print the source, target and model output\n",
    "            print_msg('-'*console_width)\n",
    "            print_msg(f\"{f'SOURCE: ':>20}{source_text}\")\n",
    "            print_msg(f\"{f'TARGET: ':>20}{target_text}\")\n",
    "            print_msg(f\"{f'PREDICTED GREEDY: ':>20}{model_out_text_greedy}\")\n",
    "            print_msg(f\"{f'PREDICTED BEAM: ':>20}{model_out_text_beam}\")\n",
    "\n",
    "            # close the listing with a final separator once enough examples were shown\n",
    "            if count == num_examples:\n",
    "                print_msg('-'*console_width)\n",
    "                break\n",
    "\n",
    "# Show two validation examples, generating at most 20 tokens per sentence\n",
    "run_validation(\n",
    "    model,\n",
    "    val_dataloader,\n",
    "    tokenizer_src,\n",
    "    tokenizer_tgt,\n",
    "    max_len=20,\n",
    "    device=device,\n",
    "    print_msg=print,\n",
    "    num_examples=2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "transformer",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
